import torch
from segment_anything import sam_model_registry, SamPredictor


class ClothingColorExtractor:
    """Hold a Segment Anything (SAM) predictor used for clothing-color extraction.

    Only construction is visible here. The constructor builds a SAM model
    from a checkpoint, moves it to the chosen device, and wraps it in a
    ``SamPredictor`` stored on ``self.predictor``.
    """

    def __init__(self, root, *, sam_checkpoint="sam_vit_h_4b8939.pth",
                 model_type="vit_h", device=None):
        """Build the SAM predictor.

        Args:
            root: Stored unchanged on ``self.root``; its role is not visible
                in this class.
            sam_checkpoint: Path to the SAM weights file. A relative path is
                resolved against the current working directory. Defaults to
                the ViT-H checkpoint name this class previously hard-coded.
            model_type: Key into ``sam_model_registry`` selecting the model
                architecture; it must match the checkpoint. Defaults to
                ``"vit_h"``.
            device: Torch device to run the model on. ``None`` (the default)
                picks ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.

        Raises:
            ValueError: If ``model_type`` is not a key of
                ``sam_model_registry``.
        """
        self.root = root

        # Resolve the device once and keep it so later code can place its
        # input tensors on the same device as the model.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Initialize the SAM model.
        try:
            build_sam = sam_model_registry[model_type]
        except KeyError:
            raise ValueError(
                f"Unknown SAM model type {model_type!r}; "
                f"expected one of {sorted(sam_model_registry)}"
            ) from None
        sam = build_sam(checkpoint=sam_checkpoint)
        sam.to(device=device)
        self.predictor = SamPredictor(sam)